import numpy as np
import pandas as pd
import statsmodels.formula.api as sm

from cat_rfs.code.utils.tools import save_stat
from cat_rfs.code.utils.table import Table


def _coef_cells(res, shown_terms, row_terms):
    """
    Lay out one regression column of the coefficient panel.

    For every term in ``row_terms`` two cells are produced: the estimate and
    its standard error (``res.params[term]`` and ``res.bse[term]``) when the
    term is listed in ``shown_terms``, or two empty strings otherwise.
    """
    cells = []
    for term in row_terms:
        if term in shown_terms:
            cells += [res.params[term], res.bse[term]]
        else:
            cells += ['', '']
    return cells


def print_table2(df: pd.DataFrame, output: str = 'console') -> None:
    """
    Print beta_determinants table and save some stats.

    Five OLS regressions of ``beta_sheet`` are fitted, plus one univariate
    regression per explanatory variable and one each for the peril
    indicators and the date fixed effects. Coefficients, standard errors,
    R^2 and N are arranged in a three-panel ``Table``; selected R^2 values,
    one coefficient and two correlations are passed to ``save_stat``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data; a copy is taken, so the caller's frame is not modified.
        Columns read here: ``perils``, one ``el_<peril>`` column per
        single-peril label found in ``perils``, ``size``, ``date``,
        ``CUSIP9``, ``sheet_price``, ``beta_sheet``, ``beta_loss``,
        ``expected_loss``, ``attachment_prob``, ``exhaustion_prob``,
        ``bmk_yield`` and every column whose name ends in ``_I``.
    output : str
        Forwarded to ``save_stat``. When equal to ``'paper'`` the table is
        written to ``cat_rfs/output/tables/table2.tex``; otherwise
        ``Table.print_table()`` is called without a path.
    """
    df = df.copy()

    # %% Calculate peril weights
    # Rows whose 'perils' entry has no comma define the list of peril names;
    # each maps to an 'el_<name>' column (spaces -> underscores, lower case).
    single_perils = df.loc[~df['perils'].str.contains(','), 'perils'].unique()
    perils = ['el_' + s.replace(' ', '_').lower() for s in single_perils]
    df['totel'] = df[perils].sum(axis='columns')

    for peril in perils:
        # Share of the row's total 'el_*' accounted for by this peril.
        df[f'{peril}_weight'] = df[peril] / df['totel']
        # Row size attributed to this peril, and its per-date total.
        df[f'{peril}_amt'] = df[f'{peril}_weight'] * df['size']
        df[f'{peril}_amt_group'] = df.groupby(['date'])[f'{peril}_amt'].transform('sum')
        df[f'{peril}_weighted'] = df[f'{peril}_weight'] * df[f'{peril}_amt_group']

    weighted_total = df[[f'{peril}_weighted' for peril in perils]].sum(axis='columns')
    amt_group_total = df[[f'{peril}_amt_group' for peril in perils]].sum(axis='columns')
    df['weight_peril'] = weighted_total / amt_group_total
    indicators = ' + '.join(df.columns[df.columns.str.endswith('_I')])
    # %% End

    df['lnsize'] = np.log(df['size'])
    df['sheet_spread'] = df['sheet_price'] * 100

    # Row order of the coefficient panel: one (estimate, std. error) pair of
    # rows per term, matching the labels in col0 below.
    row_terms = ['attachment_prob', 'expected_loss', 'exhaustion_prob', 'lnsize', 'sheet_spread', 'bmk_yield',
                 'weight_peril', 'expected_loss:weight_peril', 'beta_loss']

    col0 = ['$ap$', '', '$el$', '', '$ep$', '', '$ln(Size)$', '', '$dm$', '', '$ref$', '', '$w_c$', '',
            '$el \\times w_c$', '', '$\\hat{\\beta}_{loss}$', '']

    def fit(formula):
        # NOTE(review): cov_type is left at its default here, so statsmodels
        # appears to report non-robust standard errors and ignore the
        # 'groups' passed through cov_kwds. If clustering by date and CUSIP9
        # is intended, cov_type='cluster' is probably needed -- confirm before
        # changing, since this alters the reported standard errors.
        return sm.ols(formula, data=df).fit(cov_kwds={'groups': np.array(df[['date', 'CUSIP9']])})

    # (1) Contract characteristics with peril indicators and date fixed effects.
    res1 = fit('beta_sheet ~ expected_loss + attachment_prob + exhaustion_prob + lnsize + sheet_spread '
               f'+ bmk_yield + {indicators} + C(date)')
    col1 = _coef_cells(res1, row_terms[:6], row_terms)
    R2_row = ['$R^2$', res1.rsquared]
    N_row = ['N', str(int(res1.nobs))]
    save_stat(res1.rsquared * 100, 'beta_det_r2_c1', formatting='%.1f', output=output)

    # (2) beta_loss only.
    res2 = fit('beta_sheet ~ beta_loss')
    save_stat(res2.params['beta_loss'], 'beta_sheet_loss_coef', formatting='%.2f', output=output)
    col2 = _coef_cells(res2, ['beta_loss'], row_terms)
    R2_row += [res2.rsquared]
    N_row += [str(int(res2.nobs))]

    # (3) expected_loss and weight_peril, no interaction.
    res3 = fit('beta_sheet ~ expected_loss + weight_peril')
    col3 = _coef_cells(res3, ['expected_loss', 'weight_peril'], row_terms)
    R2_row += [res3.rsquared]
    N_row += [str(int(res3.nobs))]
    save_stat(res3.rsquared * 100, 'beta_det_r2_c3', formatting='%.1f', output=output)

    # (4) '*' in the formula: both main effects plus their interaction.
    res4 = fit('beta_sheet ~ expected_loss * weight_peril')
    col4 = _coef_cells(res4, ['expected_loss', 'weight_peril', 'expected_loss:weight_peril'], row_terms)
    R2_row += [res4.rsquared]
    N_row += [str(int(res4.nobs))]
    save_stat(res4.rsquared * 100, 'beta_det_r2_c4', formatting='%.1f', output=output)

    # (5) ':' in the formula: the interaction term only.
    res5 = fit('beta_sheet ~ expected_loss : weight_peril')
    col5 = _coef_cells(res5, ['expected_loss:weight_peril'], row_terms)
    R2_row += [res5.rsquared]
    N_row += [str(int(res5.nobs))]
    save_stat(res5.rsquared * 100, 'beta_det_r2_c5', formatting='%.1f', output=output)

    # Univariate R^2 column: one single-regressor fit per row term, with an
    # empty cell next to each value to line up with the standard-error rows.
    R2_col = []
    for var in row_terms:
        univariate_r2 = sm.ols(f'beta_sheet ~ {var}', data=df).fit().rsquared
        R2_col += [univariate_r2, '']
        save_stat(univariate_r2 * 100, f'beta_det_ur2_{var}', formatting='%.1f', output=output)

    # The univariate column has no overall R^2 / N entry.
    R2_row += ['']
    N_row += ['']

    R2_fe = [sm.ols(f'beta_sheet ~ {indicators}', data=df).fit().rsquared,
             sm.ols('beta_sheet ~ C(date)', data=df).fit().rsquared]

    save_stat(R2_fe[0] * 100, 'beta_det_ur2_perils', formatting='%.1f', output=output)
    save_stat(R2_fe[1] * 100, 'beta_det_ur2_date', formatting='%.1f', output=output)

    # Lists are rows of the intermediate frame; transposing makes them columns.
    data = pd.DataFrame([col0, col1, col2, col3, col4, col5]).T
    data[6] = R2_col

    data_fe = pd.DataFrame([['Peril FE', 'Time FE'], ['Yes', 'Yes'], ['No', 'No'], ['No', 'No'],
                            ['No', 'No'], ['No', 'No'], R2_fe]).T

    data2 = pd.DataFrame([R2_row, N_row])

    save_stat(df[['expected_loss', 'attachment_prob']].corr().iloc[0, 1], 'el_ap_corr', formatting='%.3f',
              output=output)

    save_stat(df[['expected_loss', 'exhaustion_prob']].corr().iloc[0, 1], 'el_ep_corr', formatting='%.3f',
              output=output)

    T = Table(7, width='auto', justs=['l', 'd', 'd', 'd', 'd', 'd', 'd'],
              caption='Determinants of $\\hat{\\beta}$', label='beta_determinants')

    # dof should subtract the number of regressors but this does not affect number of stars in any case.
    # The observation count of regression (5) is used for every column.
    # NOTE(review): 'skip_cols': [7] presumably targets the univariate R^2
    # column -- confirm against Table's column-indexing convention.
    T.add_panel(data, headers=['', '(1)', '(2)', '(3)', '(4)', '(5)', 'Univariate $R^2$'], formatting='%.2f',
                regression_specs={'stars': (0.1, 0.05, 0.01), 'dof': int(res5.nobs), 'skip_cols': [7]})

    T.add_panel(data_fe)

    T.add_panel(data2)

    if output == 'paper':
        T.print_table('cat_rfs/output/tables/table2.tex')
    else:
        T.print_table()
